{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.5 卷积神经网络(LeNet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-29T13:57:37.383972Z",
     "start_time": "2019-05-29T13:57:34.520559Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0.0\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "print(torch.__version__)\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.5.1 LeNet模型 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-29T13:57:37.394997Z",
     "start_time": "2019-05-29T13:57:37.386720Z"
    }
   },
   "outputs": [],
   "source": [
    "class LeNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet, self).__init__()\n",
    "        self.conv = nn.Sequential(\n",
    "            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size\n",
    "            nn.Sigmoid(),\n",
    "            nn.MaxPool2d(2, 2), # kernel_size, stride\n",
    "            nn.Conv2d(6, 16, 5),\n",
    "            nn.Sigmoid(),\n",
    "            nn.MaxPool2d(2, 2)\n",
    "        )\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(16*4*4, 120),\n",
    "            nn.Sigmoid(),\n",
    "            nn.Linear(120, 84),\n",
    "            nn.Sigmoid(),\n",
    "            nn.Linear(84, 10)\n",
    "        )\n",
    "\n",
    "    def forward(self, img):\n",
    "        feature = self.conv(img)\n",
    "        output = self.fc(feature.view(img.shape[0], -1))\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-29T13:57:37.450484Z",
     "start_time": "2019-05-29T13:57:37.397357Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LeNet(\n",
      "  (conv): Sequential(\n",
      "    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "    (1): Sigmoid()\n",
      "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "    (4): Sigmoid()\n",
      "    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (fc): Sequential(\n",
      "    (0): Linear(in_features=256, out_features=120, bias=True)\n",
      "    (1): Sigmoid()\n",
      "    (2): Linear(in_features=120, out_features=84, bias=True)\n",
      "    (3): Sigmoid()\n",
      "    (4): Linear(in_features=84, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = LeNet()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.5.2 获取数据和训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-29T13:57:38.432567Z",
     "start_time": "2019-05-29T13:57:37.452521Z"
    }
   },
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-29T13:57:38.442887Z",
     "start_time": "2019-05-29T13:57:38.435111Z"
    }
   },
   "outputs": [],
   "source": [
    "# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进：它的完整实现将在“图像增广”一节中描述\n",
    "def evaluate_accuracy(data_iter, net, device=None):\n",
    "    if device is None and isinstance(net, torch.nn.Module):\n",
    "        # 如果没指定device就使用net的device\n",
    "        device = list(net.parameters())[0].device\n",
    "    acc_sum, n = 0.0, 0\n",
    "    with torch.no_grad():\n",
    "        for X, y in data_iter:\n",
    "            if isinstance(net, torch.nn.Module):\n",
    "                net.eval() # 评估模式, 这会关闭dropout\n",
    "                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()\n",
    "                net.train() # 改回训练模式\n",
    "            else: # 自定义的模型, 3.13节之后不会用到, 不考虑GPU\n",
    "                if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数\n",
    "                    # 将is_training设置成False\n",
    "                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() \n",
    "                else:\n",
    "                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() \n",
    "            n += y.shape[0]\n",
    "    return acc_sum / n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-29T13:57:38.453480Z",
     "start_time": "2019-05-29T13:57:38.445655Z"
    }
   },
   "outputs": [],
   "source": [
    "# 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
    "def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):\n",
    "    net = net.to(device)\n",
    "    print(\"training on \", device)\n",
    "    loss = torch.nn.CrossEntropyLoss()\n",
    "    batch_count = 0\n",
    "    for epoch in range(num_epochs):\n",
    "        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()\n",
    "        for X, y in train_iter:\n",
    "            X = X.to(device)\n",
    "            y = y.to(device)\n",
    "            y_hat = net(X)\n",
    "            l = loss(y_hat, y)\n",
    "            optimizer.zero_grad()\n",
    "            l.backward()\n",
    "            optimizer.step()\n",
    "            train_l_sum += l.cpu().item()\n",
    "            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()\n",
    "            n += y.shape[0]\n",
    "            batch_count += 1\n",
    "        test_acc = evaluate_accuracy(test_iter, net)\n",
    "        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'\n",
    "              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-29T13:58:00.333237Z",
     "start_time": "2019-05-29T13:57:38.456012Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 1.7885, train acc 0.337, test acc 0.584, time 2.4 sec\n",
      "epoch 2, loss 0.4793, train acc 0.614, test acc 0.666, time 2.3 sec\n",
      "epoch 3, loss 0.2637, train acc 0.704, test acc 0.720, time 2.3 sec\n",
      "epoch 4, loss 0.1747, train acc 0.734, test acc 0.740, time 2.2 sec\n",
      "epoch 5, loss 0.1282, train acc 0.751, test acc 0.749, time 2.2 sec\n"
     ]
    }
   ],
   "source": [
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
